9.3 - Code - Implementing an Action Mask

This section provides the complete, updated code required to implement action masking in our DRL_RM_Bot environment. This is a critical upgrade that makes training the generalized agent feasible by focusing its exploration on valid actions only.
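
To see why this matters, recall what a mask does inside the policy. Conceptually (and roughly what MaskablePPO does internally), the logits of invalid actions are pushed to a very large negative value before the softmax, so those actions receive essentially zero probability and consume no exploration. A tiny, self-contained NumPy illustration (not part of the bot code):

import numpy as np

def masked_softmax(logits: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Drive the probability of invalid actions to (effectively) zero."""
    masked_logits = np.where(mask, logits, -1e8)  # invalid actions get a huge negative logit
    exp = np.exp(masked_logits - masked_logits.max())
    return exp / exp.sum()

logits = np.array([1.2, 0.3, -0.5, 2.0])
mask = np.array([True, True, False, False])  # only the first two actions are currently valid
print(masked_softmax(logits, mask))          # all probability mass lands on the valid actions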

Implementation Workflow

  • Step 1: Install the sb3-contrib library.
  • Step 2: Update drl_rm_bot.py to support the new Dict observation space and generate the mask.
  • Step 3: Update train.py to use the MaskablePPO algorithm.

Step 1 - Install sb3-contrib

In your activated virtual environment, install the companion library:

pip install sb3-contrib
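
If you want to double-check the install, both stable_baselines3 and sb3_contrib should expose a version string you can print:

python -c "import stable_baselines3, sb3_contrib; print(stable_baselines3.__version__, sb3_contrib.__version__)"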

Step 2 - The Code: Updated drl_rm_bot.py with Masking Logic

This is the modified version of our DRL-RM environment. The most significant changes are in the DRL_RM_Env class, which now defines a Dict observation space with the "obs" and "action_mask" keys, and in the DRL_RM_Bot class, which gains a new method for generating the mask.

drl_rm_bot.py

import numpy as np
import gymnasium as gym
from gymnasium.spaces import Box, Dict, MultiDiscrete
import queue

from sc2.bot_ai import BotAI
from sc2.data import Race
from sc2.units import Units
from sc2_gym_env import SC2GymEnv

# --- Environment Constants ---
MAX_UNITS = 50
NUM_UNIT_FEATURES = 5
NUM_ABILITIES = 2 # 0=Move, 1=Attack


class DRL_RM_Bot(BotAI):
    """
    The DRL-RM bot, now updated to generate an action mask.
    """
    # NOTE: _encode_state() and _decode_and_execute_action() are kept from the previous section.
    # For brevity, only the new/changed methods are shown here.

    def __init__(self, action_queue, obs_queue):
        super().__init__()
        self.action_queue = action_queue
        self.obs_queue = obs_queue
        self.race = Race.Terran

    def _get_action_mask(self) -> np.ndarray:
        """
        Calculates a flat boolean mask for all possible actions.
        An action is valid if its actor and target indices are within bounds
        of the current number of units.
        """
        my_units_len = len(self.units)
        all_units_len = len(self.all_units)

        # Create a 2D mask for valid actor-target pairs.
        # A cell (i, j) is True if actor i and target j are valid units.
        actor_mask = np.arange(MAX_UNITS) < my_units_len
        target_mask = np.arange(MAX_UNITS * 2) < all_units_len
        valid_pairs = np.logical_and(actor_mask[:, None], target_mask[None, :])

        # Expand dimensions to (MAX_UNITS, 1, MAX_UNITS * 2).
        expanded_mask = valid_pairs[:, None, :]

        # Tile along the ability dimension, assuming all abilities are potentially valid.
        # This creates the final 3D mask of shape (50, 2, 100).
        final_mask = np.tile(expanded_mask, (1, NUM_ABILITIES, 1))

        # Flatten the 3D mask into a 1D vector for the agent.
        return final_mask.flatten()

    async def on_step(self, iteration: int):
        # The main loop now follows a cleaner Observe -> Act cycle.
        if iteration % 8 == 0:
            # 1. OBSERVE: Get the current state and generate the observation and mask.
            observation = self._encode_state()
            action_mask = self._get_action_mask()

            # The observation is now a dictionary with "obs" and "action_mask" keys.
            obs_dict = {
                "obs": observation,
                "action_mask": action_mask
            }

            terminated = self.townhalls.amount == 0

            # 2. SEND OBSERVATION: Send the data to the agent and wait for an action.
            self.obs_queue.put((obs_dict, 0.1, terminated, False, {}))

            if terminated:
                await self.client.leave()
                return

            # 3. ACT: Get the (now guaranteed valid) action from the agent and execute it.
            try:
                action = self.action_queue.get_nowait()
                await self._decode_and_execute_action(action)
            except queue.Empty:
                pass


class DRL_RM_Env(SC2GymEnv):
    """The Gymnasium wrapper, updated to use a Dict observation space."""

    def __init__(self):
        super().__init__(bot_class=DRL_RM_Bot, map_name="AcropolisLE")

        self.action_space = MultiDiscrete([MAX_UNITS, NUM_ABILITIES, MAX_UNITS * 2])

        # The observation space is now a Dict that carries both the unit features and the mask.
        self.observation_space = Dict({
            # The key "obs" holds our original unit feature matrix.
            "obs": Box(low=0, high=1, shape=(MAX_UNITS, NUM_UNIT_FEATURES), dtype=np.float32),
            # The key "action_mask" holds the boolean mask.
            "action_mask": Box(low=0, high=1, shape=(self.action_space.nvec.prod(),), dtype=bool)
        })
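
A practical note on how the mask actually reaches the algorithm: in sb3-contrib's documented interface, MaskablePPO does not pull the mask out of the observation dictionary. It calls an action_masks() method on the environment at every rollout step (either directly or through the ActionMasker wrapper), and for a MultiDiscrete action space it expects one mask per action dimension, concatenated to length sum(nvec) = 50 + 2 + 100 = 152, rather than the flattened joint mask. The sketch below shows one way DRL_RM_Env could expose such a method by caching the most recent mask and reducing it per dimension. It is only a sketch: it assumes SC2GymEnv follows the standard Gymnasium reset()/step() signatures, which are not shown in this section.

# Sketch: possible additions to DRL_RM_Env so MaskablePPO can query the mask.
class DRL_RM_Env(SC2GymEnv):
    # ... __init__ exactly as above ...

    def reset(self, *, seed=None, options=None):
        obs, info = super().reset(seed=seed, options=options)
        self._last_joint_mask = obs["action_mask"]
        return obs, info

    def step(self, action):
        obs, reward, terminated, truncated, info = super().step(action)
        self._last_joint_mask = obs["action_mask"]
        return obs, reward, terminated, truncated, info

    def action_masks(self) -> np.ndarray:
        # MaskablePPO calls this each step. For MultiDiscrete actions it expects
        # one mask per dimension, concatenated: shape (50 + 2 + 100,) = (152,).
        joint = self._last_joint_mask.reshape(MAX_UNITS, NUM_ABILITIES, MAX_UNITS * 2)
        if not joint.any():
            # Degenerate case (no controllable units): leave everything unmasked.
            return np.ones(MAX_UNITS + NUM_ABILITIES + MAX_UNITS * 2, dtype=bool)
        actor_valid = joint.any(axis=(1, 2))    # which of the 50 actor slots hold real units
        ability_valid = joint.any(axis=(0, 2))  # both abilities stay valid while any pair is
        target_valid = joint.any(axis=(0, 1))   # which of the 100 target slots hold real units
        return np.concatenate([actor_valid, ability_valid, target_valid])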

Step 3 - The Code: Updated train.py

Your training script needs two small changes: import MaskablePPO from sb3_contrib, and switch the policy string to "MultiInputPolicy", which Stable-Baselines3 requires whenever the observation space is a Dict. Remember that the mask itself reaches MaskablePPO through the environment's action_masks() method noted at the end of Step 2, not through the observation.

# train.py
import multiprocessing as mp

# Import from sb3_contrib instead of stable_baselines3
from sb3_contrib import MaskablePPO

from drl_rm_bot import DRL_RM_Env

def main():
    env = DRL_RM_Env()

    # Use the MaskablePPO class. A Dict observation space requires the
    # "MultiInputPolicy" rather than the usual "MlpPolicy".
    model = MaskablePPO("MultiInputPolicy", env, verbose=1)

    model.learn(total_timesteps=200_000)
    model.save("ppo_masked_drl_rm")
    env.close()

if __name__ == '__main__':
    mp.freeze_support()
    main()
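
When you later evaluate the trained model, supply the mask at inference time as well; MaskablePPO.predict() accepts it through its action_masks argument. A minimal evaluation loop, assuming the action_masks() method sketched at the end of Step 2 is in place:

# evaluate.py (sketch)
import multiprocessing as mp

from sb3_contrib import MaskablePPO
from drl_rm_bot import DRL_RM_Env

def main():
    env = DRL_RM_Env()
    model = MaskablePPO.load("ppo_masked_drl_rm")

    obs, info = env.reset()
    terminated = truncated = False
    while not (terminated or truncated):
        mask = env.action_masks()  # the per-dimension mask sketched in Step 2
        action, _ = model.predict(obs, action_masks=mask, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(action)

    env.close()

if __name__ == '__main__':
    mp.freeze_support()
    main()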

With these modifications, your framework is now equipped with action masking, a crucial technique for making the complex DRL-RM agent trainable.